// Copyright 2023 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package actions

import (
	"context"
	"crypto"
	_ "crypto/sha256" // Links SHA-256 so crypto.SHA256.New() in sha256String cannot panic.
	"errors"
	"fmt"
	"log"
	"os"
	"runtime"
	"sync"

	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/types/known/anypb"

	luciproto "go.chromium.org/luci/common/proto"
	"go.chromium.org/luci/common/system/environ"

	"go.chromium.org/luci/cipkg/core"
)

// Executor performs the action described by an action spec of type M.
//
// When called, it runs the action carried by spec and places the artifacts it
// generates under the output path outDir.
type Executor[M proto.Message] func(ctx context.Context, spec M, outDir string) error

// envCipkgExec names the environment variable marking a re-executed
// invocation. ReexecDerivation sets it, and Intercept removes it from the
// environment and diverts execution to a registered Executor when it is found.
const envCipkgExec = "_CIPKG_EXEC_CMD"

// ReexecRegistry holds the Executors for actions that must be implemented by
// the binary itself, which runs them by re-executing its own executable.
//
// NewReexecRegistry pre-registers Executors for:
//   - core.ActionURLFetch
//   - core.ActionFilesCopy
//   - core.ActionCIPDExport
//
// For reexec actions to work, the binary must call .Intercept() very early in
// its `main()`. Intercept recognizes a re-executed invocation of the program
// and hands control to the matching registered Executor.
//
// Programs may register additional Executors with the SetExecutor or
// MustSetExecutor functions from this package, but only before Intercept is
// called.
type ReexecRegistry struct {
	// mu guards sealed and every write to execs.
	mu sync.Mutex

	// sealed becomes true once Intercept runs; after that SetExecutor refuses
	// new registrations, and execs is only read (without holding mu).
	sealed bool

	// execs maps an action spec's full proto message name to its Executor.
	execs map[protoreflect.FullName]Executor[proto.Message]
}

// NewReexecRegistry returns an unsealed ReexecRegistry with the built-in
// Executors for core.ActionURLFetch, core.ActionFilesCopy and
// core.ActionCIPDExport already registered. Registration goes through
// MustSetExecutor, so it panics if any of them cannot be registered.
func NewReexecRegistry() *ReexecRegistry {
	r := &ReexecRegistry{execs: map[protoreflect.FullName]Executor[proto.Message]{}}

	MustSetExecutor[*core.ActionURLFetch](r, ActionURLFetchExecutor)
	MustSetExecutor[*core.ActionFilesCopy](r, defaultFilesCopyExecutor.Execute)
	MustSetExecutor[*core.ActionCIPDExport](r, ActionCIPDExportExecutor)

	return r
}

var (
	// ErrExecutorExisted is returned by SetExecutor when an Executor is already
	// registered for the same message type.
	ErrExecutorExisted = errors.New("executor for the message type already existed")

	// ErrReexecRegistrySealed is returned by SetExecutor once Intercept has
	// been called on the registry.
	ErrReexecRegistrySealed = errors.New("executor can't be set after Intercept being called")
)

// SetExecutor registers execFunc as the Executor for the action spec M.
//
// Executors must be registered before Intercept is called; afterwards
// SetExecutor returns ErrReexecRegistrySealed. If an Executor is already
// registered under M's message name, it returns ErrExecutorExisted and keeps
// the existing registration.
func SetExecutor[M proto.Message](r *ReexecRegistry, execFunc Executor[M]) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.sealed {
		return ErrReexecRegistrySealed
	}

	// The zero value of M is only used to look up its message name.
	var zero M
	key := proto.MessageName(zero)
	if _, taken := r.execs[key]; taken {
		return ErrExecutorExisted
	}

	// Adapt the typed executor to the registry's proto.Message signature.
	// NOTE(review): assumes the message later dispatched under this name has Go
	// type M; the type assertion panics otherwise.
	r.execs[key] = func(ctx context.Context, m proto.Message, out string) error {
		typed := m.(M)
		return execFunc(ctx, typed, out)
	}
	return nil
}

// MustSetExecutor registers execFunc for the action spec M exactly like
// SetExecutor, but panics with the error instead of returning it.
func MustSetExecutor[M proto.Message](r *ReexecRegistry, execFunc Executor[M]) {
	err := SetExecutor[M](r, execFunc)
	if err == nil {
		return
	}
	panic(err)
}

// Intercept seals the registry and, if the _CIPKG_EXEC_CMD environment
// variable is present, runs the registered Executor and exits; otherwise it
// returns so the program continues normally.
//
// This is REQUIRED for reexec to work. Any application using the framework
// must call .Intercept() very early in its `main()` so that a re-executed
// invocation is diverted to the registered Executor. It must run after init(),
// because embedded filesystems or other resources may be registered in init().
//
// On Windows, the NoDefaultCurrentDirectoryInExePath environment variable is
// always set so that executables are not searched for in the current working
// directory by default; such lookups are relative and therefore forbidden by Go.
func (r *ReexecRegistry) Intercept(ctx context.Context) {
	if runtime.GOOS == "windows" {
		err := os.Setenv("NoDefaultCurrentDirectoryInExePath", "1")
		if err != nil {
			panic(fmt.Sprintf("failed to set NoDefaultCurrentDirectoryInExePath on Windows: %s", err))
		}
	}

	// Snapshot the environment only after the Windows adjustment above, so the
	// executor inherits it as well.
	sysEnv := environ.System()
	r.interceptWithArgs(ctx, sysEnv, os.Args, os.Exit)
}

// interceptWithArgs is the testable core of Intercept. The env, args and exit
// parameters stand in for the process environment, os.Args and os.Exit.
//
// It always seals r. If envCipkgExec is not present in env it returns
// immediately. Otherwise it removes envCipkgExec from env and decodes args[1]
// as a protojson-encoded anypb.Any. It then runs the Executor registered for
// the wrapped message type, with env installed in ctx and env's "out" value as
// the output path. Finally it terminates through exit: exit(0) on success, or
// exit(1) after logging the executor's error.
// It panics on missing or malformed arguments and on unregistered message types.
func (r *ReexecRegistry) interceptWithArgs(ctx context.Context, env environ.Env, args []string, exit func(int)) {
	// Seal unconditionally so nothing can be registered after this point,
	// whether or not this invocation is a re-exec. execs is only read below.
	r.mu.Lock()
	r.sealed = true
	r.mu.Unlock()

	// NOTE(review): removing the marker presumably keeps processes started from
	// env by the executor from being intercepted again.
	if !env.Remove(envCipkgExec) {
		return
	}

	if len(args) < 2 {
		panic(fmt.Sprintf("usage: cipkg-exec <proto>: insufficient args: %s", args))
	}

	// Named anyMsg rather than `any` to avoid shadowing the predeclared type.
	var anyMsg anypb.Any
	if err := protojson.Unmarshal([]byte(args[1]), &anyMsg); err != nil {
		panic(fmt.Sprintf("failed to unmarshal anypb: %s, %s", err, args))
	}
	msg, err := anyMsg.UnmarshalNew()
	if err != nil {
		panic(fmt.Sprintf("failed to unmarshal proto from any: %s, %s", err, args))
	}
	f := r.execs[proto.MessageName(msg)]
	if f == nil {
		panic(fmt.Sprintf("unknown cipkg-exec command: %s", args))
	}

	if err := f(env.SetInCtx(ctx), msg, env.Get("out")); err != nil {
		// Terminate through the injected exit instead of log.Fatalln, which
		// calls os.Exit directly and would bypass it. Output and exit code
		// are unchanged when exit is os.Exit.
		log.Println(err)
		exit(1)
		return
	}

	exit(0)
}

// reexecVersion is the global reexec version. It prefixes every FixedOutput
// produced by ReexecDerivation, so changing it changes the FixedOutput of all
// such derivations.
const reexecVersion = "v1"

// ReexecDerivation returns a derivation that re-executes the current binary to
// perform the action spec m.
//
// m is wrapped in an anypb.Any and passed, protojson-encoded, as the only
// argument after the executable path. The envCipkgExec marker is set in the
// derivation's environment so that Intercept handles the invocation. The
// FixedOutput comes from a versioned hash of the wrapped message.
// If hostEnv is true, the environment starts as a copy of the host
// environment; otherwise it starts empty.
func ReexecDerivation(m proto.Message, hostEnv bool) (*core.Derivation, error) {
	self, err := os.Executable()
	if err != nil {
		return nil, err
	}

	packed, err := anypb.New(m)
	if err != nil {
		return nil, err
	}
	encoded, err := protojson.Marshal(packed)
	if err != nil {
		return nil, err
	}
	fixed, err := sha256String(packed)
	if err != nil {
		return nil, err
	}

	var env environ.Env
	if hostEnv {
		env = environ.System()
	} else {
		env = environ.New(nil)
	}
	env.Set(envCipkgExec, "1")

	drv := &core.Derivation{
		Args:        []string{self, string(encoded)},
		Env:         env.Sorted(),
		FixedOutput: fixed,
	}
	return drv, nil
}

// sha256String returns a versioned digest of m. It computes a SHA-256 hash of
// m with luciproto.StableHash and formats the result as
// reexecVersion + algorithm name + ":" + hex digest.
//
// crypto.Hash.New panics when the hash implementation is not linked into the
// binary. This function reports that case as an error instead, and the
// StableHash error is returned unchanged.
func sha256String(m proto.Message) (string, error) {
	const algo = crypto.SHA256
	if !algo.Available() {
		return "", fmt.Errorf("hash function %s is not linked into the binary (import crypto/sha256)", algo)
	}
	h := algo.New()
	if err := luciproto.StableHash(h, m); err != nil {
		return "", err
	}
	return fmt.Sprintf("%s%s:%x", reexecVersion, algo, h.Sum(nil)), nil
}
